## Read the job index from the command line ----
## Usage: Rscript <this script> <index>
## <index> selects an entry of N_rerun, which in turn selects a row of the
## settings grid below.
cmd_args <- commandArgs(trailingOnly = TRUE)
print("begin")
print(cmd_args)
print("end")
# Parse as an integer rather than eval(parse()): the argument is only ever an
# index, and evaluating arbitrary command-line text is unsafe and hides typos.
N_index <- suppressWarnings(as.integer(cmd_args[1]))
if (is.na(N_index)) {
  stop("First command-line argument must be an integer index, got: ",
       cmd_args[1], call. = FALSE)
}

# For interactive testing: N_index <- 20L
library(rstan)

# NOTE(review): N_rerun is not defined in this script; presumably it is loaded
# from this .Rdata file -- TODO confirm.
load("GMM/gmm_data.Rdata")

## Settings grid: one row per (even/odd kernel flag, ladder size, replicate,
## stoch_time). The name `args` is kept in case later code refers to it.
args <- expand.grid(NR = c(TRUE, FALSE),
                    num_temp = c(3, 5, 7),
                    Nrep = 1:20,
                    stoch_time = c(0.1, 1, 2))

if (N_index < 1L || N_index > length(N_rerun)) {
  stop("Index ", N_index, " is outside N_rerun (length ",
       length(N_rerun), ")", call. = FALSE)
}
N_index <- N_rerun[N_index]
stopifnot(!is.na(N_index), N_index >= 1, N_index <= nrow(args))

# Access grid columns by name, not position, so reordering the grid is safe.
rep_1 <- args$Nrep[N_index]
ntemp <- args$num_temp[N_index]
NR <- args$NR[N_index]
stoch_time <- args$stoch_time[N_index]
print(args[N_index, ])

###### Config ####
iters <- 10000   # number of samples / number of events
warmup <- 10000  # number of samples / events to burn
poly_order <- 3

# The replicate number doubles as the RNG seed, so the initial point is
# reproducible per replicate.
set.seed(rep_1)
x_init <- rnorm(2, mean = 5, sd = 5)
theta_init <- c(1, 1)

## ct HMC: instantiate the evaluation model ----
## A single iteration on a single chain is requested; the resulting fit object
## serves as the handle passed to grad_log_prob() further down.
## Alternative model path (previously used on a home machine):
## stan_fit_eval <- stan("gmm/model_eval.stan", data = dat, iter = 1, chains = 1)
stan_fit_eval <- stan(
  file = "GMM/ct_hmc_eval.stan",
  data = dat,
  iter = 1,
  chains = 1
)

## Setup ct-ZZ ----
## target() and temper() evaluate the same Stan model at the unconstrained
## parameter vector c(x, indicator); they differ only in the indicator value.

# Evaluate the Stan model's log density and its gradient with respect to x.
#
# x:         numeric state vector (length 2 in this script).
# indicator: value appended after x in the parameter vector passed to
#            grad_log_prob(); 1 for target(), 0 for temper().
# Returns list(log_q = log density, d_log_q = gradient entries for x only).
.eval_stan_component <- function(x, indicator) {
  stan_ev <- grad_log_prob(stan_fit_eval, c(x, indicator))
  # Keep only the gradient components belonging to x; drop the indicator's.
  list(log_q = attr(stan_ev, "log_prob"),
       d_log_q = as.numeric(stan_ev)[seq_along(x)])
}

# Log density and x-gradient with the indicator set to 1.
target <- function(x) {
  .eval_stan_component(x, 1)
}

# Log density and x-gradient with the indicator set to 0.
temper <- function(x) {
  .eval_stan_component(x, 0)
}

source("parra_temp.R")

## Curvature bounds used by the sampler ----
## mean_diffs: squared half-range of each row of mu_mat
## (assumes rows of mu_mat correspond to coordinates of x -- TODO confirm).
mean_diffs <- (apply(mu_mat, 1, max) - apply(mu_mat, 1, min))^2 / 4

C_1 <- 1 / dat$sigma[1]
hess_q1 <- abs(C_1 * (1 + C_1 * mean_diffs))
hess_q0 <- 1 / dat$sigma0

# Row "q0" is filled from dat$sigma0, row "q1" from dat$sigma[1] and mu_mat.
# byrow = TRUE (not T): parra_temp.R was just sourced into this environment
# and could have rebound `T`. The outer parentheses print hess to the log.
(hess <- abs(matrix(c(hess_q0,
                      hess_q1), nrow = 2, ncol = 2, byrow = TRUE)))
rownames(hess) <- c("q0", "q1")
##################################################
## Run parallel tempering once per base tempering level IT ----
## For each IT the ladder is c(IT^ntemp, ..., IT^1, 1). Every run reuses the
## same seed and initial point, so the runs differ only in IT.
## Results are written to GMM/PT/ and freed after each run.
for (IT in c(0.1, 0.3, 0.5, 0.7)) {
  temp <- c(IT^c(ntemp:1), 1)
  # Re-sourced for every IT, as the original per-IT blocks did; presumably
  # parra_temp.R reads globals such as `temp` when sourced -- TODO confirm.
  source("parra_temp.R")

  # pt() is run for warmup + iters iterations in total.
  set.seed(rep_1)
  res <- pt(xinit = x_init, stoch_time = stoch_time,
            temp = temp, Nit = iters + warmup,
            even_odd_kernel = NR)

  save(res, file = paste0("GMM/PT/res_S_", stoch_time, "_NT_", ntemp,
                          "_NR_", NR, "_rep_", rep_1, "_IT_", IT, ".Rdata"))
  # Free the result before the next run so two results never coexist.
  rm(res)
  gc()
}


